(use-modules
 (ffi blis)
 (srfi srfi-1)
 (ice-9 match))

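;; Build a small data abstraction over Guile arrays, so that the code below can
;; talk about matrices, their shapes and their dimensions instead of raw arrays.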

(define (matrix-shape mat)
  "Return the shape of the array MAT exactly as array-shape reports it: one
(lower upper) index-bound list per dimension."
  (array-shape mat))

(define (matrix-dimensions mat)
  "Return the dimensions of the array MAT exactly as array-dimensions reports
them: one entry per dimension, which is the length of that dimension when its
lower bound is 0."
  (array-dimensions mat))

(define (get-dims-rows dims)
  "Return the row count of a matrix, i.e. the first entry of the dimension
list DIMS (as produced by matrix-dimensions)."
  (first dims))

(define (get-dims-cols dims)
  "Return the column count of a matrix, i.e. the second entry of the dimension
list DIMS (as produced by matrix-dimensions)."
  (second dims))

;; "For matrix multiplication, the number of columns in the first matrix must be
;; equal to the number of rows in the second matrix."
;; (https://en.wikipedia.org/wiki/Matrix_multiplication)

(define matrix-multiply!
  (lambda* (mat-a mat-b mat-res #:optional (alpha 1.0) (beta 0.0))
    "Multiply mat-a and mat-b and store the result in mat-res. Return mat-res.

This calls the BLIS gemm! operation with neither operand transposed, which
computes

  mat-res := alpha * (mat-a x mat-b) + beta * mat-res

ALPHA defaults to 1.0 and BETA to 0.0, so by default the previous contents of
mat-res do not contribute and mat-res ends up holding exactly the product.
Pass BETA = 1.0 explicitly to accumulate the product into the existing
contents of mat-res instead.

mat-a must have as many columns as mat-b has rows, and mat-res must have the
shape (rows of mat-a, columns of mat-b).  NOTE(review): this procedure does not
check that itself; it presumably relies on gemm! rejecting mismatched shapes."
    ;; BETA must default to 0.0: with beta = 1.0, gemm! adds the product to
    ;; whatever mat-res already holds (garbage for an uninitialized array)
    ;; instead of storing the product as documented above.
    (gemm! BLIS_NO_TRANSPOSE BLIS_NO_TRANSPOSE
           alpha
           mat-a mat-b
           beta
           mat-res)
    mat-res))


(define simple-matrix-multiply
  (lambda (mat-a mat-b)
    "Calculate the product of 2 matrices, automatically creating another matrix
with the correct dimensions, to which the result will be written.  Return that
newly created result matrix."
    (define mat-a-dim (matrix-dimensions mat-a))
    (define mat-b-dim (matrix-dimensions mat-b))
    (matrix-multiply! mat-a
                      mat-b
                      ;; The result of a matrix multiplication A x B has the
                      ;; shape (rows of A, columns of B). The library requires
                      ;; us to pass the matrix that receives the result as an
                      ;; additional argument.
                      ;; Fill it with 0.0 instead of *unspecified* (which leaves
                      ;; the storage uninitialized): gemm! computes
                      ;; C := alpha*A*B + beta*C, so any garbage already in C
                      ;; would leak into the result whenever beta is non-zero.
                      (make-typed-array 'f64
                                        0.0
                                        (get-dims-rows mat-a-dim)
                                        (get-dims-cols mat-b-dim)))))


;; Demo: multiply a 4x3 matrix of ones by a 3x4 matrix of ones and print the
;; product in `write' notation, followed by a newline.
(let* (;; Both operands are typed arrays whose elements are of type 'f64, a
       ;; 64-bit float (what C calls a double), with lower index bounds (0 0).
       [four-by-three (list->typed-array 'f64
                                         '(0 0)
                                         '((1 1 1)
                                           (1 1 1)
                                           (1 1 1)
                                           (1 1 1)))]
       [three-by-four (list->typed-array 'f64
                                         '(0 0)
                                         '((1 1 1 1)
                                           (1 1 1 1)
                                           (1 1 1 1)))]
       [product (simple-matrix-multiply four-by-three three-by-four)])
  (display (simple-format #f "~s\n" product)))
